package com.hj.aiagent.rag;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.stereotype.Component;

/**
 * @author: hj
 * @description: Query rewriter — rewrites a user prompt through Spring AI's
 *               {@link RewriteQueryTransformer}, which is backed by the injected chat model.
 * @date: 2025/7/4 9:10
 */
@Component
public class QueryRewriter {

    /** Transformer that performs the rewrite; built once at construction and reused for every call. */
    private final QueryTransformer transformer;

    /**
     * Builds the rewrite transformer on top of the given chat model.
     *
     * <p>The parameter name is kept deliberately: Spring may use it to select the
     * {@code ChatModel} bean by name when several are present — confirm against the
     * application's bean configuration before renaming it.
     *
     * @param dashscopeChatModel chat model used by the transformer to produce the rewritten query
     */
    QueryRewriter(ChatModel dashscopeChatModel) {
        ChatClient.Builder clientBuilder = ChatClient.builder(dashscopeChatModel);
        this.transformer = RewriteQueryTransformer.builder()
                .chatClientBuilder(clientBuilder)
                .build();
    }

    /**
     * Rewrites the given prompt into a (presumably retrieval-friendlier) query text.
     *
     * <p>The prompt is wrapped in a {@link Query}, passed through the rewrite transformer,
     * and the text of the resulting query is returned. NOTE(review): Spring AI's
     * {@code Query} appears to reject null/empty text — confirm for the version in use.
     *
     * @param prompt the original user prompt
     * @return the rewritten query text
     */
    public String doQueryRewrite(String prompt) {
        return transformer
                .transform(new Query(prompt))
                .text();
    }
}
